2024-06-14 19:08| 来源: 网络整理| 查看: 265




        以目标检测模型为例，分析 YOLOv8 如何通过 model = YOLO('yolov8n.yaml') 加载模型。典型用法如下：


from ultralytics import YOLO

# Load a model (three alternative ways)
model = YOLO('yolov8n.yaml')  # build a new model from the YAML config
model = YOLO('yolov8n.pt')  # load a model from the .pt file
model = YOLO('yolov8n.yaml').load('yolov8n.pt')  # build from YAML, then load weights from the .pt file

# Train the model
results = model.train(data='coco128.yaml', epochs=100, imgsz=640)


class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.

        Args:
            model (str | Path): model file, e.g. .pt weights or a .yaml/.yml config.
            task (str | None): task type, forwarded to Model.__init__ (not used on the YOLOWorld path).
            verbose (bool): forwarded to Model.__init__ (not used on the YOLOWorld path).
        """
        path = Path(model)
        # For model='yolov8n.yaml' the stem is 'yolov8n' (no '-world'), so the else-branch runs.
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            # Build a YOLOWorld instance and let this object take over its class and state.
            new_instance = YOLOWorld(path)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)

    @property
    def task_map(self):
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }

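       在 YOLO 类的 __init__() 函数中，首先用 Path(model) 解析文件名，并根据 path.stem 是否包含 '-world'、且 path.suffix 是否为 .pt/.yaml/.yml 来决定执行哪个分支：条件成立则切换为 YOLOWorld，否则调用父类 Model 的 __init__()。以 'yolov8n.yaml' 为例，stem 和 suffix 的取值如下，因此执行 else 分支：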

from pathlib import Path

model = 'yolov8n.yaml'
path = Path(model)
# Label each value so the printed output matches the expected output listed below.
print(f'path.stem:{path.stem}')
print(f'path.suffix:{path.suffix}')
# Output:
# path.stem:yolov8n
# path.suffix:.yaml


class Model(nn.Module):
    """Base wrapper class: resolves the `model` argument and builds or loads the underlying network."""

    def __init__(
        self,
        model: Union[str, Path] = "yolov8n.pt",
        task: str = None,
        verbose: bool = False,
    ) -> None:
        """
        Initialize the wrapper from a model reference.

        Resolution order:
          1. Ultralytics HUB model reference -> open a HUB session and continue with its model file.
          2. Triton Server URL -> store the URL as self.model and return immediately.
          3. Otherwise a .yaml/.yml path builds a new model via self._new(); any other path is
             loaded via self._load().

        Args:
            model (str | Path): HUB reference, Triton URL, YAML config, or model file (e.g. .pt).
            task (str | None): task type, passed on to self._new() / self._load().
            verbose (bool): forwarded to self._new() only.
        """
        super().__init__()
        self.callbacks = callbacks.get_default_callbacks()
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None
        self.overrides = {}  # overrides for trainer object
        self.metrics = None  # validation/training metrics
        self.session = None  # HUB session
        self.task = task  # task type
        self.model_name = model = str(model).strip()  # strip spaces

        # Check if Ultralytics HUB model from https://hub.ultralytics.com
        if self.is_hub_model(model):
            # Fetch model from HUB
            checks.check_requirements("hub-sdk>0.0.2")
            self.session = self._get_hub_session(model)
            model = self.session.model_file

        # Check if Triton Server model
        elif self.is_triton_model(model):
            self.model = model
            self.task = task
            return

        # Load or create new YOLO model
        model = checks.check_model_file_from_stem(model)  # add suffix, i.e. yolov8n -> yolov8n.pt
        if Path(model).suffix in (".yaml", ".yml"):
            self._new(model, task=task, verbose=verbose)
        else:
            self._load(model, task=task)

        self.model_name = model


# Excerpt from Model.__init__: first check whether `model` refers to an Ultralytics HUB model.
# For 'yolov8n.yaml' is_hub_model() returns False, so this branch is skipped.
if self.is_hub_model(model):


def is_hub_model(model: str) -> bool:
    """
    Check if the provided model is a HUB model.

    Three forms are recognised:
      * a HUB model URL starting with f"{HUB_WEB_ROOT}/models/";
      * "<APIKEY>_<MODELID>" whose two underscore-separated parts are 42 and 20 characters long;
      * a bare 20-character MODELID that is not an existing local path and contains no '.',
        '/' or backslash.
    """
    return any(
        (
            model.startswith(f"{HUB_WEB_ROOT}/models/"),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
            [len(x) for x in model.split("_")] == [42, 20],  # APIKEY_MODELID
            len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),  # MODELID
        )
    )


# Root URL of the Ultralytics HUB web app; can be overridden via the ULTRALYTICS_HUB_WEB environment variable.
HUB_WEB_ROOT = os.getenv("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

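        对于 'yolov8n.yaml'：它不以 f"{HUB_WEB_ROOT}/models/" 开头；按 "_" 切分后各段长度为 [12]，不等于 [42, 20]；其长度为 12 而不是 20。三个条件都不满足，因此该函数返回 False，接着检查 elif self.is_triton_model(model)。is_triton_model(model) 函数如下：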

def is_triton_model(model: str) -> bool:
    """
    Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>.

    Returns True only when the URL has a non-empty netloc, a non-empty path and the scheme
    is "http" or "grpc"; e.g. 'yolov8n.yaml' (empty scheme and netloc) gives False.
    """
    from urllib.parse import urlsplit

    url = urlsplit(model)
    # Coerce to a real bool: `url.netloc and url.path and ...` would otherwise return the empty
    # netloc/path string ('') to callers despite the `-> bool` annotation.
    return bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}


from urllib.parse import urlsplit

path = 'yolov8n.yaml'
url = urlsplit(path)
print(f'url:{url}')
# Output:
# url:SplitResult(scheme='', netloc='', path='yolov8n.yaml', query='', fragment='')


        由上面的输出可知，netloc 为空，scheme 也不是 http 或 grpc，因此 is_triton_model() 返回假值，elif 分支同样不执行。接着执行 model = checks.check_model_file_from_stem(model)，checks.check_model_file_from_stem(model) 函数如下：

def check_model_file_from_stem(model="yolov8n"):
    """
    Return a model filename from a valid model stem.

    A bare stem (no suffix) that is listed in downloads.GITHUB_ASSETS_STEMS, e.g. 'yolov8n', is
    returned as a Path with a '.pt' suffix. Anything else (empty value, a name that already has a
    suffix such as 'yolov8n.yaml', or an unknown stem) is returned unchanged, with its original type.
    """
    if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
        return Path(model).with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
    else:
        return model

        由于 Path(model).suffix 为 '.yaml'，不为空，因此直接执行 return model。注意此时参数 model 仍然为 'yolov8n.yaml'。又因为其后缀是 .yaml，满足下面的条件，于是调用 self._new()：


# Excerpt from Model.__init__: a .yaml/.yml suffix means "build a new model from a config file"
# (self._new); any other suffix is handled by self._load.
if Path(model).suffix in (".yaml", ".yml"):


# Branch taken for model='yolov8n.yaml': build the model from the YAML config via Model._new().
self._new(model, task=task, verbose=verbose)


def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
    """
    Initializes a new model and infers the task type from the model definitions.

    Args:
        cfg (str): model configuration file (YAML), parsed with yaml_model_load().
        task (str | None): model task; when falsy it is inferred with guess_model_task(cfg_dict).
        model (callable | None): customized model class/factory, called as model(cfg_dict, verbose=...);
            when None, the object returned by self._smart_load("model") is used instead.
        verbose (bool): display model info on load; only forwarded as True when RANK == -1
            (presumably the single-process / non-distributed case — confirm).
    """
    cfg_dict = yaml_model_load(cfg)
    self.cfg = cfg
    self.task = task or guess_model_task(cfg_dict)
    self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1)  # build model
    self.overrides["model"] = self.cfg
    self.overrides["task"] = self.task

    # Below added to allow export from YAMLs
    self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
    self.model.task = self.task


# First step of Model._new(): parse the YAML config into the dict used to build the model.
cfg_dict = yaml_model_load(cfg)


def yaml_model_load(path):
    """
    Load a YOLOv8 model from a YAML file.

    Steps:
      1. Legacy P6 stems 'yolov{5,8}{n,s,m,l,x}6' are renamed to the '-p6' form with a warning.
      2. The scale letter is stripped ('yolov8n.yaml' -> 'yolov8.yaml') and that shared config is
         looked up first without raising; on a miss the (possibly renamed) original path is used.
      3. The YAML is loaded into a dict, which is annotated with 'scale' and 'yaml_file'.

    Args:
        path (str | Path): model config name or path, e.g. 'yolov8n.yaml'.

    Returns:
        dict: the model definition plus d["scale"] (from guess_model_scale) and d["yaml_file"].
    """
    import re

    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    # hard=False returns a falsy value on a miss instead of raising, so fall back to the original path.
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d

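        第一个 if 语句的结果为 False：生成器 f"yolov{d}{x}6" 产生的是 'yolov5n6'、'yolov8x6' 这类末尾带 6 的旧 P6 模型名，而此处 path.stem 为 'yolov8n'，不在其中。只有传入 'yolov8n6.yaml' 这类旧名称时才会进入该分支，并被重命名为 'yolov8n-p6.yaml'。接下来用正则去掉表示模型尺度的字母，得到统一的配置文件名：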


import re
from pathlib import Path

path = Path('yolov8n.yaml')  # the path yaml_model_load() works with for model='yolov8n.yaml'
unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
print(f'unified_path:{unified_path}')
# Output:
# unified_path:yolov8.yaml


# Try the scale-agnostic name first (e.g. yolov8.yaml); with hard=False a miss returns a falsy
# value instead of raising, in which case fall back to looking up the original path.
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)


def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """
    Search/download YAML file (if necessary) and return path, checking suffix.

    Thin wrapper around check_file() that enforces a YAML suffix; `hard` is forwarded unchanged
    (when False, a file that cannot be found yields a falsy result instead of an exception).
    """
    return check_file(file, suffix, hard=hard)

        check_file() 函数如下:

def check_file(file, suffix="", download=True, hard=True):
    """
    Search/download file (if necessary) and return path.

    Resolution order:
      1. Empty input, an existing local path, or a grpc:// URL -> returned as-is.
      2. An http(s)/rtsp/rtmp/tcp URL (when `download` is True) -> mapped to a local filename with
         url2file() and downloaded unless that file already exists.
      3. Anything else -> searched recursively under ROOT / "cfg".

    Args:
        file (str | Path): filename, path or URL.
        suffix (str | tuple): accepted suffix(es), checked by check_suffix() (optional).
        download (bool): allow downloading URLs in step 2.
        hard (bool): raise in step 3 when there is no match or more than one match.

    Returns:
        str | list: the resolved path; in step 3 with hard=False, the first match, or an empty
        list when nothing matches.

    Raises:
        FileNotFoundError: in step 3 when `hard` is True and zero or several files match.
    """
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file
    else:  # search
        files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True)  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else []  # return file



# Inside check_file(): map legacy YOLOv3/YOLOv5 names to their 'u' variants; other names pass through unchanged.
file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu


def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.

    Only strings containing 'yolov3' or 'yolov5' are touched:
      * '...u.yaml' configs map back to the plain config ('yolov5nu.yaml' -> 'yolov5n.yaml');
      * '.pt' weights get the 'u' variant ('yolov5n.pt' -> 'yolov5nu.pt',
        'yolov5n6.pt' -> 'yolov5n6u.pt', 'yolov3-spp.pt' -> 'yolov3-sppu.pt').
    Everything else (e.g. 'yolov8.yaml') is returned unchanged.

    Args:
        file (str): model filename or path.
        verbose (bool): log a tip when a .pt name was rewritten.

    Returns:
        str: the possibly rewritten filename.
    """
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        # NOTE: the 'u' test looks at the whole string, so any 'u' elsewhere in a path also skips the rewrite.
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                LOGGER.info(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file

        由于 'yolov8.yaml' 中既不含 'yolov3' 也不含 'yolov5'，该函数最外层的 if 条件为 False，直接 return file，file 保持不变。


if ( not file or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python :/ file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth if Path(file).exists(): LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists else: downloads.safe_download(url=url, file=file, unzip=False) return file

        由于此时 file 的取值为 'yolov8.yaml'：not file 为 False；当前工作目录下一般没有同名文件，所以 Path(file).exists() 为 False；file.lower().startswith('grpc://') 也为 False，因此第一个 if 不成立。elif 语句中的 file.lower().startswith(('https://', ...)) 同样为 False。最终会直接运行如下 else 分支：

else: # search files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file if not files and hard: raise FileNotFoundError(f"'{file}' does not exist") elif len(files) > 1 and hard: raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") return files[0] if len(files) else [] # return file


# ROOT comes from ./ultralytics-main/ultralytics/utils/__init__.py:
#     FILE = Path(__file__).resolve()
#     ROOT = FILE.parents[1]  # YOLO
# FILE.parents[1] of utils/__init__.py is the ultralytics package directory, so this searches
# ultralytics/cfg/**/<file> recursively.
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True)  # find file



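        glob 会在 ultralytics/cfg 目录下递归查找 yolov8.yaml，通常只匹配到一个文件（例如 ultralytics/cfg/models/v8/yolov8.yaml），因此不会触发 "does not exist" 或 "Multiple files match" 异常，而是直接执行 return files[0] if len(files) else []，返回该配置文件的路径。回到 yaml_model_load()：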


def yaml_model_load(path):
    """
    Load a YOLOv8 model from a YAML file.

    Steps:
      1. Legacy P6 stems 'yolov{5,8}{n,s,m,l,x}6' are renamed to the '-p6' form with a warning.
      2. The scale letter is stripped ('yolov8n.yaml' -> 'yolov8.yaml') and that shared config is
         looked up first without raising; on a miss the (possibly renamed) original path is used.
      3. The YAML is loaded into a dict, which is annotated with 'scale' and 'yaml_file'.

    Args:
        path (str | Path): model config name or path, e.g. 'yolov8n.yaml'.

    Returns:
        dict: the model definition plus d["scale"] (from guess_model_scale) and d["yaml_file"].
    """
    import re

    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    # The resolved config path is obtained here; hard=False returns a falsy value on a miss instead
    # of raising, so fall back to the original path.
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d

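        其中，yaml_file 用来接收 yolov8.yaml 的路径，d = yaml_load(yaml_file) 则用来加载 yolov8.yaml 的文件内容，d 为字典形式。随后函数在 d 中写入 scale（由 guess_model_scale(path) 根据原始文件名推断）和 yaml_file（原始路径 'yolov8n.yaml'），再返回给 _new()。_new() 据此推断任务类型并构建模型。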

        至此,则梳理明白了YOLOv8如何从model = YOLO('yolov8n.yaml')加载模型的原理。





